package com.sunyard.utils.sm2;

import com.sunyard.utils.alipayutil.AliPayUtil;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.generators.ECKeyPairGenerator;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ECKeyGenerationParameters;
import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.crypto.util.Pack;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.math.ec.ECFieldElement;
import org.bouncycastle.math.ec.ECPoint;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.security.SecureRandom;

/**
 * SM2 (GM/T 0003) digital-signature helper: Z_A computation, message digest, signing and
 * verification, using {@link SM3Digest} as the hash function.
 *
 * <p>Note: despite its name, this class implements the SM2 public-key algorithm; the name is
 * kept so that existing callers continue to compile.
 *
 * <p>Signatures produced by {@link #Sign} / {@link #SignWithPublicKey} and accepted by
 * {@link #verify} are hex strings of the form {@code hex(r) + hex(s)}, where each component is
 * left-padded with zero bytes to the curve's byte length.
 *
 * Created by puke on 2019/11/27
 */
public class SM3 {

    private static final Logger logger = LoggerFactory.getLogger(SM3.class);

    private static final String CHARSET_UTF_8 = "UTF-8";

    /** ENTL_A is a 16-bit bit-length, so a user ID may be at most 0xFFFF / 8 = 8191 bytes. */
    private static final int MAX_USER_ID_BYTES = 0xFFFF / 8;

    private final BigInteger eccP;
    private final BigInteger eccA;
    private final BigInteger eccB;
    private final BigInteger eccN;
    private final BigInteger eccXG;
    private final BigInteger eccYG;

    private final ECCurve eccCurve;
    private final ECPoint eccPointG;
    /** Byte length of a field element; coordinates and r, s are left-padded to this length. */
    private final int byteLen;

    public final ECDomainParameters ecc_bc_spec;
    public final ECKeyPairGenerator ecc_key_pair_generator;

    public static final SM3 sm3 = new SM3();

    /**
     * Returns the shared instance.
     *
     * @return the shared instance created at class-initialization time
     */
    public static SM3 getInstance() {
        return sm3;
    }

    /**
     * Constructs an SM2 elliptic-curve instance using the curve parameters defined by
     * GM/T 0003.5-2012, and a key-pair generator (backed by a new {@link SecureRandom}) used to
     * draw random signing nonces.
     */
    public SM3() {
        eccP = new BigInteger("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16);
        eccA = new BigInteger("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC", 16);
        eccB = new BigInteger("28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 16);
        eccN = new BigInteger("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16);
        eccXG = new BigInteger("32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 16);
        eccYG = new BigInteger("BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 16);

        eccCurve = new ECCurve.Fp(eccP, eccA, eccB);
        ECFieldElement eccGxFieldelement = new ECFieldElement.Fp(eccP, eccXG);
        ECFieldElement eccGyFieldelement = new ECFieldElement.Fp(eccP, eccYG);
        eccPointG = new ECPoint.Fp(eccCurve, eccGxFieldelement, eccGyFieldelement);
        byteLen = (int) Math.ceil(eccCurve.getFieldSize() / 8.0d);

        ecc_bc_spec = new ECDomainParameters(eccCurve, eccPointG, eccN);
        ECKeyGenerationParameters keyGenParams =
                new ECKeyGenerationParameters(ecc_bc_spec, new SecureRandom());

        ecc_key_pair_generator = new ECKeyPairGenerator();
        ecc_key_pair_generator.init(keyGenParams);
    }

    /**
     * Computes the SM2 message digest SM3(Z_A || M), encoding the message and user ID as UTF-8.
     *
     * @param rawDataStr the message
     * @param usrId      the user identifier (ID_A)
     * @param pubKeyHex  the public key as hex {@code X || Y} (an optional leading {@code 04} is accepted)
     * @return the digest bytes
     * @throws UnsupportedEncodingException declared for backward compatibility; UTF-8 is always supported
     */
    public byte[] digest(String rawDataStr, String usrId, String pubKeyHex) throws UnsupportedEncodingException {
        AssertUtil.strIsNotBlank(rawDataStr, "rawDataStr is blank.");
        AssertUtil.strIsNotBlank(usrId, "usrId is blank.");
        AssertUtil.strIsNotBlank(pubKeyHex, "pubKeyHex is blank.");

        logger.debug("待验签字符串：[{}]", rawDataStr);

        return digest(rawDataStr.getBytes(CHARSET_UTF_8), usrId.getBytes(CHARSET_UTF_8), pubKeyHex);
    }

    /**
     * Computes the SM2 message digest SM3(Z_A || M) for pre-encoded input (use this overload when
     * the message or user ID is not UTF-8).
     *
     * @param rawData   the message bytes
     * @param usrId     the user identifier bytes (ID_A)
     * @param pubKeyHex the public key as hex {@code X || Y} (an optional leading {@code 04} is accepted)
     * @return the digest bytes
     */
    public byte[] digest(byte[] rawData, byte[] usrId, String pubKeyHex) {
        AssertUtil.objIsNotNull(rawData, "rawDataStr is blank.");
        AssertUtil.objIsNotNull(usrId, "usrId is blank.");
        AssertUtil.strIsNotBlank(pubKeyHex, "pubKeyHex is blank.");

        byte[] za = getZA(usrId, buildPA(pubKeyHex));
        if (logger.isDebugEnabled()) {
            logger.debug("[ZA hex={}]", HexUtil.toHex(za));
        }
        return hashWithZA(za, rawData);
    }

    /**
     * Verifies an SM2 signature (GM/T 0003.2 steps B1-B7) over a UTF-8 encoded message.
     *
     * @param rawDataStr the signed message
     * @param usrId      the signer's user identifier (ID_A)
     * @param signHex    the signature as hex {@code r || s}; the string is split in half
     * @param pubKeyHex  the signer's public key as hex {@code X || Y} (an optional leading {@code 04} is accepted)
     * @return true if the signature is valid for the message, user ID and public key
     * @throws UnsupportedEncodingException declared for backward compatibility; UTF-8 is always supported
     */
    public boolean verify(String rawDataStr, String usrId, String signHex, String pubKeyHex) throws UnsupportedEncodingException {
        AssertUtil.strIsNotBlank(rawDataStr, "rawDataStr is blank.");
        AssertUtil.strIsNotBlank(usrId, "usrId is blank.");
        AssertUtil.strIsNotBlank(signHex, "signHex is blank.");
        AssertUtil.strIsNotBlank(pubKeyHex, "pubKeyHex is blank.");

        logger.debug("待验签字符串：[{}]", rawDataStr);
        logger.debug("签名原文：[{}]", signHex);

        byte[] rawData = rawDataStr.getBytes(CHARSET_UTF_8);
        byte[] idaBytes = usrId.getBytes(CHARSET_UTF_8);

        int half = signHex.length() / 2;
        // B1/B2: r' and s' must both lie in [1, n-1]
        BigInteger rSq = new BigInteger(signHex.substring(0, half), 16);
        if (!isValidSignatureComponent(rSq)) {
            return false;
        }
        BigInteger sSq = new BigInteger(signHex.substring(half), 16);
        if (!isValidSignatureComponent(sSq)) {
            return false;
        }

        // B3/B4: e' = SM3(Z_A || M')
        ECPoint pa = buildPA(pubKeyHex);
        BigInteger eSq = new BigInteger(1, hashWithZA(getZA(idaBytes, pa), rawData));

        // B5: t = (r' + s') mod n, must be non-zero
        BigInteger t = rSq.add(sSq).mod(eccN);
        if (t.signum() == 0) {
            return false;
        }

        // B6: (x1', y1') = s'G + tP_A; the point at infinity has no x-coordinate
        ECPoint p1Sq = eccPointG.multiply(sSq).add(pa.multiply(t));
        if (p1Sq.isInfinity()) {
            return false;
        }

        // B7: R = (e' + x1') mod n, accept iff R == r'
        BigInteger r = eSq.add(p1Sq.getX().toBigInteger()).mod(eccN);
        return rSq.equals(r);
    }

    /**
     * Returns true when {@code v} lies in [1, n-1], the valid range for signature components.
     */
    private boolean isValidSignatureComponent(BigInteger v) {
        return v.compareTo(BigInteger.ONE) >= 0 && v.compareTo(eccN) < 0;
    }

    /**
     * Computes SM3(Z_A || message), the value e used by SM2 signing and verification.
     *
     * @param za      the Z_A hash
     * @param message the message bytes
     * @return the digest bytes
     */
    private byte[] hashWithZA(byte[] za, byte[] message) {
        byte[] mLine = new byte[za.length + message.length];
        System.arraycopy(za, 0, mLine, 0, za.length);
        System.arraycopy(message, 0, mLine, za.length, message.length);

        SM3Digest sm3Digest = new SM3Digest();
        sm3Digest.update(mLine, 0, mLine.length);
        byte[] out = new byte[sm3Digest.getDigestSize()];
        sm3Digest.doFinal(out, 0);
        return out;
    }

    /**
     * Parses a hex public key into a curve point.
     *
     * @param pubKeyHex {@code X || Y} in hex, optionally prefixed with {@code 04} (uncompressed
     *                  encoding) when its length is exactly {@code 2 + 4 * byteLen} characters
     * @return the public-key point (not validated to be on the curve)
     */
    private ECPoint buildPA(String pubKeyHex) {
        String coords = pubKeyHex;
        // Strip the uncompressed-point marker so X and Y are split at the right position.
        if (coords.length() == 2 + 4 * byteLen && coords.startsWith("04")) {
            coords = coords.substring(2);
        }
        String publicKeyXHex = coords.substring(0, coords.length() / 2);
        String publicKeyYHex = coords.substring(coords.length() / 2);
        if (logger.isDebugEnabled()) {
            logger.debug("[publicKeyXHex={}]", publicKeyXHex);
            logger.debug("[publicKeyYHex={}]", publicKeyYHex);
        }
        return getPoint(new BigInteger(publicKeyXHex, 16), new BigInteger(publicKeyYHex, 16));
    }

    /**
     * Converts a non-negative big integer to a big-endian byte array left-padded with zeros to
     * {@code byteLen} bytes. Values wider than {@code byteLen} bytes are returned unpadded.
     *
     * @param bi the non-negative integer to convert
     * @return the converted byte array
     */
    private byte[] bigIntegerToByteArray(BigInteger bi) {
        byte[] raw = bi.toByteArray();

        // toByteArray() prepends a 0x00 sign byte when the top bit is set; drop it.
        if (raw[0] == 0 && raw.length > byteLen) {
            byte[] trimmed = new byte[raw.length - 1];
            System.arraycopy(raw, 1, trimmed, 0, trimmed.length);
            raw = trimmed;
        }

        if (raw.length >= byteLen) {
            return raw;
        }
        byte[] padded = new byte[byteLen];
        System.arraycopy(raw, 0, padded, byteLen - raw.length, raw.length);
        return padded;
    }

    /**
     * Builds a point on the curve from affine coordinates.
     *
     * @param x the x-coordinate
     * @param y the y-coordinate
     * @return the point
     */
    private ECPoint getPoint(BigInteger x, BigInteger y) {
        ECFieldElement xElement = new ECFieldElement.Fp(eccP, x);
        ECFieldElement yElement = new ECFieldElement.Fp(eccP, y);
        return new ECPoint.Fp(eccCurve, xElement, yElement);
    }

    /**
     * Encodes ENTL_A, the bit length of ID_A, as two big-endian bytes.
     *
     * @param idA the user identifier bytes
     * @return the two-byte ENTL_A
     * @throws IllegalArgumentException if ID_A is longer than 8191 bytes (ENTL_A would overflow 16 bits)
     */
    private byte[] getEntlA(byte[] idA) {
        if (idA.length > MAX_USER_ID_BYTES) {
            throw new IllegalArgumentException(
                    "SM2 user ID must be at most " + MAX_USER_ID_BYTES + " bytes, got " + idA.length);
        }
        int bits = idA.length * 8;
        byte[] entlAT = new byte[4];
        Pack.intToBigEndian(bits, entlAT, 0);
        byte[] entlA = new byte[2];
        System.arraycopy(entlAT, 2, entlA, 0, 2);
        return entlA;
    }

    /**
     * Computes Z_A = SM3(ENTL_A || ID_A || a || b || xG || yG || xA || yA).
     *
     * @param userId    the user identifier bytes (ID_A)
     * @param publicKey the user's public key point
     * @return Z_A
     */
    private byte[] getZA(byte[] userId, ECPoint publicKey) {
        SM3Digest digest = new SM3Digest();

        byte[] entlA = getEntlA(userId);
        digest.update(entlA, 0, entlA.length);
        digest.update(userId, 0, userId.length);

        // Curve parameters a, b, base point G, then the public key, each padded to byteLen.
        BigInteger[] parts = {
                eccA, eccB,
                eccXG, eccYG,
                publicKey.getX().toBigInteger(), publicKey.getY().toBigInteger()
        };
        for (BigInteger part : parts) {
            byte[] p = bigIntegerToByteArray(part);
            digest.update(p, 0, p.length);
        }

        byte[] z = new byte[digest.getDigestSize()];
        digest.doFinal(z, 0);
        return z;
    }

    /**
     * Signs a message with SM2, computing Z_A from a caller-supplied public key.
     *
     * @param userId    the signer's user identifier bytes (ID_A)
     * @param M         the message bytes
     * @param priKey    the private key d_A as a base-10 string (parsed with {@code new BigInteger(priKey)})
     * @param pubKeyHex the public key as hex {@code X || Y} (an optional leading {@code 04} is accepted);
     *                  it is not checked against {@code priKey}
     * @param k         a fixed nonce, or null / zero to use a fresh random nonce (normal use)
     * @return the signature as hex {@code r || s}
     * @throws IllegalArgumentException if a fixed {@code k} cannot produce a valid signature
     */
    public String SignWithPublicKey(byte[] userId, byte[] M, String priKey, String pubKeyHex, BigInteger k) {
        BigInteger privateKey = new BigInteger(priKey);

        byte[] ebyte = hashWithZA(getZA(userId, buildPA(pubKeyHex)), M);
        AliPayUtil.getLogger().info("使用SM2公钥生成摘要：" + HexUtil.toHex(ebyte));

        BigInteger[] rs = computeSignature(new BigInteger(1, ebyte), privateKey, k);
        BigInteger r = rs[0];
        BigInteger s = rs[1];

        AliPayUtil.getLogger().info("得到签名(r,s)：");
        AliPayUtil.getLogger().info("BigInteger r: " + r);
        AliPayUtil.getLogger().info("BigInteger s: " + s);
        AliPayUtil.getLogger().info("BigIntegerToByteArray r: " + bigIntegerToByteArray(r).length);
        AliPayUtil.getLogger().info("BigIntegerToByteArray s: " + bigIntegerToByteArray(s).length);
        AliPayUtil.getLogger().info("r: " + HexUtil.toHex(bigIntegerToByteArray(r)));
        AliPayUtil.getLogger().info("s: " + HexUtil.toHex(bigIntegerToByteArray(s)));

        return toSignatureHex(r, s);
    }

    /**
     * Converts a non-negative big integer to a big-endian byte array left-padded with zeros to
     * the curve's byte length.
     *
     * @param bi the non-negative integer to convert
     * @return the converted byte array
     */
    public byte[] BigIntegerToByteArray(BigInteger bi) {
        return bigIntegerToByteArray(bi);
    }

    /**
     * Derives the public key P = dG from a private key d.
     *
     * @param d the private key
     * @return the public-key point
     */
    public ECPoint GetPublicKey(BigInteger d) {
        return eccPointG.multiply(d);
    }

    /**
     * Signs a message with SM2, deriving the public key (for Z_A) from the private key.
     *
     * @param userId     the signer's user identifier bytes (ID_A)
     * @param M          the message to sign
     * @param privateKey the signer's private key d_A
     * @param k          a fixed nonce, or null / zero to use a fresh random nonce (normal use)
     * @return the signature as hex {@code r || s}
     * @throws IllegalArgumentException if a fixed {@code k} cannot produce a valid signature
     */
    public String Sign(byte[] userId, byte[] M, BigInteger privateKey,
                       BigInteger k) {
        byte[] ebyte = hashWithZA(getZA(userId, GetPublicKey(privateKey)), M);
        AliPayUtil.getLogger().info("使用私钥计算公钥生成摘要：" + HexUtil.toHex(ebyte));

        BigInteger[] rs = computeSignature(new BigInteger(1, ebyte), privateKey, k);
        BigInteger r = rs[0];
        BigInteger s = rs[1];

        AliPayUtil.getLogger().info("得到签名(r,s)：");
        AliPayUtil.getLogger().info("r: " + HexUtil.toHex(bigIntegerToByteArray(r)));
        AliPayUtil.getLogger().info("s: " + HexUtil.toHex(bigIntegerToByteArray(s)));

        return toSignatureHex(r, s);
    }

    /**
     * Runs GM/T 0003.2 signing steps A3-A6 for the digest e.
     *
     * @param e          the digest SM3(Z_A || M) as a non-negative integer
     * @param privateKey the signer's private key d_A
     * @param k          a fixed nonce, or null / zero to draw a fresh random nonce on every attempt
     * @return {@code {r, s}}
     * @throws IllegalArgumentException if a fixed {@code k} yields r == 0, r + k == n or s == 0
     */
    private BigInteger[] computeSignature(BigInteger e, BigInteger privateKey, BigInteger k) {
        boolean fixedNonce = k != null && !BigInteger.ZERO.equals(k);
        // (1 + dA)^-1 mod n does not depend on k, so compute it once outside the retry loop.
        BigInteger inverseOnePlusD = privateKey.add(BigInteger.ONE).modInverse(eccN);

        while (true) {
            BigInteger nonce;
            ECPoint kp;
            if (fixedNonce) {
                nonce = k;
                kp = eccPointG.multiply(nonce);
            } else {
                // A3: a fresh random k (and kG) is drawn for every attempt.
                AsymmetricCipherKeyPair keypair = ecc_key_pair_generator.generateKeyPair();
                nonce = ((ECPrivateKeyParameters) keypair.getPrivate()).getD();
                kp = ((ECPublicKeyParameters) keypair.getPublic()).getQ();
            }

            // A5: r = (e + x1) mod n; reject r == 0 or r + k == n
            BigInteger r = e.add(kp.getX().toBigInteger()).mod(eccN);
            if (!r.equals(BigInteger.ZERO) && !r.add(nonce).equals(eccN)) {
                // A6: s = ((1 + dA)^-1 * (k - r * dA)) mod n; reject s == 0
                BigInteger s = nonce.subtract(r.multiply(privateKey)).mod(eccN);
                s = inverseOnePlusD.multiply(s).mod(eccN);
                if (!s.equals(BigInteger.ZERO)) {
                    return new BigInteger[] {r, s};
                }
            }

            if (fixedNonce) {
                // Retrying with the same k would produce the same rejected result forever.
                throw new IllegalArgumentException(
                        "supplied k yields an invalid SM2 signature; use another k or pass null");
            }
        }
    }

    /**
     * Encodes (r, s) as {@code hex(r) + hex(s)}, each component padded to the curve's byte length.
     */
    private String toSignatureHex(BigInteger r, BigInteger s) {
        return HexUtil.toHex(bigIntegerToByteArray(r)) + HexUtil.toHex(bigIntegerToByteArray(s));
    }

}